#include "opencl_source_map.hpp" 
namespace MNN {
#ifndef MNN_OPENCL_BUFFER_CLOSED
// OpenCL C program text for the self-attention buffer kernels, stored as one
// string: the adjacent literals below are concatenated by the C++ compiler,
// and the C++ comments placed between them are NOT part of the string.
//
// Kernels contained in the program:
//   split_transpose_qkv - splits a fused QKV input into separate, zero-padded
//                         Q, K and V buffers (V is written transposed).
//   softmax_inside      - softmax over the innermost dimension, optionally
//                         writing the result transposed (OUTPUT_TRANSPOSE).
//   trans_3d_buf        - swaps the two inner dimensions in 8x8 tiles.
//   clip_transpose_qkv  - copies a padded [head, head_dim, seq] buffer back to
//                         the packed output, dropping the head_dim padding.
//
// FLOAT / FLOAT4 / FLOAT8 and the optional switches (MNN_SUPPORT_FP16,
// HEADDIM_LEAVE, SEQLEN_LEAVE, SOFTMAX_LOCAL_SIZE, OUTPUT_TRANSPOSE) are not
// defined in this text; presumably the host passes them as build options.
const char* self_attention_buf =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// Shared helpers: leading global-size parameters and the matching bounds guard.
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
// Padding helpers for split_transpose_qkv. In that kernel temp_k holds head-dim
// row 4*hd+k and its .x/.y/.z/.w lanes hold sequence positions 4*sl+0..3 (see
// how temp_0..temp_3 are stored to output_q/output_k/output_v). Therefore:
//   - head-dim overflow must clear whole vectors temp_1..temp_3,
//   - seq-len overflow must clear lanes .y/.z/.w of every vector.
"#define DEAL_HEAD_DIM_NOT_ALIGN  if(hd * 4 + 3 >= head_dim) { temp_3 = (FLOAT4)0; } if(hd * 4 + 2 >= head_dim) { temp_2 = (FLOAT4)0; } if(hd * 4 + 1 >= head_dim) { temp_1 = (FLOAT4)0; }\n"
"#define DEAL_SEQ_LEN_NOT_ALIGN  if(4 * sl + 3 >= seq_len) { temp_0.w = (FLOAT)0; temp_1.w = (FLOAT)0; temp_2.w = (FLOAT)0; temp_3.w = (FLOAT)0; } if(4 * sl + 2 >= seq_len) { temp_0.z = (FLOAT)0; temp_1.z = (FLOAT)0; temp_2.z = (FLOAT)0; temp_3.z = (FLOAT)0; } if(4 * sl + 1 >= seq_len) { temp_0.y = (FLOAT)0; temp_1.y = (FLOAT)0; temp_2.y = (FLOAT)0; temp_3.y = (FLOAT)0; }\n"
// ---------------------------------------------------------------------------
// Kernel: split_transpose_qkv
// One work item handles a 4 (head_dim) x 4 (sequence) tile of one head.
// When seq_index > 0 only output_q is written (for that sequence piece);
// otherwise Q (if inside the piece), K and V are all written.
// ---------------------------------------------------------------------------
"__kernel void split_transpose_qkv(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,// [Batch,seqLen/4,mNumHead*3*mHeadDim,4]\n"
" __global FLOAT *output_q,// [Batch*mNumHead,head_dim_pack_k,seq_len_pack_mn/qSeqSplitNum]\n"
" __global FLOAT *output_k,// [Batch*mNumHead,head_dim_pack_k,seq_len_pack_mn]\n"
" __global FLOAT *output_v,// [Batch*mNumHead,ROUND_UP(seqLen,tile),head_dim_pack_mn]\n"
" __private const int seq_len_pack_mn,\n"
" __private const int seq_len_piece,\n"
" __private const int head_dim_pack_mn,\n"
" __private const int head_dim_pack_k,\n"
" __private const int seq_len,\n"
" __private const int head_num,\n"
" __private const int head_dim,\n"
" __private const int batch,\n"
" __private const int seq_index\n"
") {\n"
" const int sl=get_global_id(0); // seqLen_4\n"
" const int hd=get_global_id(1); // mHeadDim_4\n"
" const int z=get_global_id(2); // Batch*mNumHead\n"
" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n"
" \n"
" const int b=z/head_num;\n"
" const int hn=z % head_num;\n"
" \n"
" const int offset_q=((b*head_num+hn)*head_dim_pack_k+4*hd)*seq_len_piece+4*sl;\n"
// Path 1: later sequence pieces only refresh Q.
" if(seq_index>0) {\n"
" // fill output_q only\n"
" if(sl*4 >= seq_len || hd*4 >= head_dim) {\n"
" if(hd*4<head_dim_pack_k) {\n"
" if(sl*4<seq_len_piece) {\n"
" vstore4((FLOAT4)0,0,output_q+offset_q);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece+seq_len_piece);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece+seq_len_piece+seq_len_piece);\n"
" }\n"
" }\n"
" return;\n"
" }\n"
" \n"
" const int offset_inp=((((seq_index*seq_len_piece/4+sl)*batch+b)*head_num+hn)*3*head_dim+4*hd)*4;\n"
" if(sl*4<seq_len_piece) {\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+4);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+8);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+12);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_HEAD_DIM_NOT_ALIGN\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_SEQ_LEN_NOT_ALIGN\n"
" #endif\n"
" vstore4(temp_0,0,output_q+offset_q);\n"
" vstore4(temp_1,0,output_q+offset_q+seq_len_piece);\n"
" vstore4(temp_2,0,output_q+offset_q+seq_len_piece+seq_len_piece);\n"
" vstore4(temp_3,0,output_q+offset_q+seq_len_piece+seq_len_piece+seq_len_piece);\n"
" }\n"
" return;\n"
" }\n"
// Path 2 (seq_index == 0): tiles fully outside the valid data are zero-filled.
" const int offset_k=((b*head_num+hn)*head_dim_pack_k+4*hd)*seq_len_pack_mn+4*sl;\n"
" const int offset_v=((b*head_num+hn)*seq_len_pack_mn+4*sl)*head_dim_pack_mn+4*hd;\n"
" if(sl*4 >= seq_len || hd*4 >= head_dim) {\n"
" if(hd*4<head_dim_pack_k) {\n"
" if(sl*4<seq_len_piece) {\n"
" vstore4((FLOAT4)0,0,output_q+offset_q);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece+seq_len_piece);\n"
" vstore4((FLOAT4)0,0,output_q+offset_q+seq_len_piece+seq_len_piece+seq_len_piece);\n"
" }\n"
" vstore4((FLOAT4)0,0,output_k+offset_k);\n"
" vstore4((FLOAT4)0,0,output_k+offset_k+seq_len_pack_mn);\n"
" vstore4((FLOAT4)0,0,output_k+offset_k+seq_len_pack_mn+seq_len_pack_mn);\n"
" vstore4((FLOAT4)0,0,output_k+offset_k+seq_len_pack_mn+seq_len_pack_mn+seq_len_pack_mn);\n"
" }\n"
" vstore4((FLOAT4)0,0,output_v+offset_v);\n"
" vstore4((FLOAT4)0,0,output_v+offset_v+head_dim_pack_mn);\n"
" vstore4((FLOAT4)0,0,output_v+offset_v+head_dim_pack_mn+head_dim_pack_mn);\n"
" vstore4((FLOAT4)0,0,output_v+offset_v+head_dim_pack_mn+head_dim_pack_mn+head_dim_pack_mn);\n"
" \n"
" return;\n"
" }\n"
" \n"
" const int offset_inp=(((sl*batch+b)*head_num+hn)*3*head_dim+4*hd)*4;\n"
" \n"
// Q: rows are head-dim, columns are sequence positions.
" if(sl*4<seq_len_piece) {\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+4);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+8);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+12);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_HEAD_DIM_NOT_ALIGN\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_SEQ_LEN_NOT_ALIGN\n"
" #endif\n"
" vstore4(temp_0,0,output_q+offset_q);\n"
" vstore4(temp_1,0,output_q+offset_q+seq_len_piece);\n"
" vstore4(temp_2,0,output_q+offset_q+seq_len_piece+seq_len_piece);\n"
" vstore4(temp_3,0,output_q+offset_q+seq_len_piece+seq_len_piece+seq_len_piece);\n"
" }\n"
" \n"
// K sits head_dim entries (4*head_dim scalars) after Q inside the fused input.
" {\n"
" // K\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp+4*head_dim);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+4*head_dim+4);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+4*head_dim+8);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+4*head_dim+12);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_HEAD_DIM_NOT_ALIGN\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_SEQ_LEN_NOT_ALIGN\n"
" #endif\n"
" \n"
" vstore4(temp_0,0,output_k+offset_k);\n"
" vstore4(temp_1,0,output_k+offset_k+seq_len_pack_mn);\n"
" vstore4(temp_2,0,output_k+offset_k+seq_len_pack_mn+seq_len_pack_mn);\n"
" vstore4(temp_3,0,output_k+offset_k+seq_len_pack_mn+seq_len_pack_mn+seq_len_pack_mn);\n"
" \n"
// V follows K; it is written transposed (rows = sequence, columns = head-dim).
" // V\n"
" temp_0=vload4(0,input+offset_inp+8*head_dim);\n"
" temp_1=vload4(0,input+offset_inp+8*head_dim+4);\n"
" temp_2=vload4(0,input+offset_inp+8*head_dim+8);\n"
" temp_3=vload4(0,input+offset_inp+8*head_dim+12);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_HEAD_DIM_NOT_ALIGN\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_SEQ_LEN_NOT_ALIGN\n"
" #endif\n"
" \n"
" vstore4((FLOAT4){temp_0.x,temp_1.x,temp_2.x,temp_3.x},0,output_v+offset_v);\n"
" vstore4((FLOAT4){temp_0.y,temp_1.y,temp_2.y,temp_3.y},0,output_v+offset_v+head_dim_pack_mn);\n"
" vstore4((FLOAT4){temp_0.z,temp_1.z,temp_2.z,temp_3.z},0,output_v+offset_v+head_dim_pack_mn+head_dim_pack_mn);\n"
" vstore4((FLOAT4){temp_0.w,temp_1.w,temp_2.w,temp_3.w},0,output_v+offset_v+head_dim_pack_mn+head_dim_pack_mn+head_dim_pack_mn);\n"
" }\n"
"}\n"
// ---------------------------------------------------------------------------
// Kernel: softmax_inside
// Softmax over the first inside_len entries of each row of length shape.z;
// the remaining entries [inside_len, shape.z) of the output are set to zero.
// With SOFTMAX_LOCAL_SIZE >= 4 the max and the exp-sum are reduced across the
// work group through local memory; otherwise each work item loops serially.
// ---------------------------------------------------------------------------
"#ifndef SOFTMAX_LOCAL_SIZE\n"
" #define SOFTMAX_LOCAL_SIZE 512\n"
"#endif\n"
"// [outside,axis,inside] -> reduce: inside\n"
"__kernel void softmax_inside(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,// [batch*mNumHead,ROUND_UP(seqLen,tile),ROUND_UP(seqLen,tile)]\n"
" __global FLOAT *output,\n"
" __private const int inside_len,\n"
" __private const int4 shape // [batch*mNumHead,ROUND_UP(seqLen,tile),ROUND_UP(seqLen,tile)]\n"
" ) {\n"
" const int inside=get_global_id(0);\n"
" const int axis=get_global_id(1);\n"
" const int outside=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(inside,axis,outside);\n"
" const int offset=(outside*shape.y+axis)*shape.z+0;\n"
// Work-group variant: strided per-item partials, then a tree reduction.
"#if SOFTMAX_LOCAL_SIZE >= 4\n"
" int lid=get_local_id(0);\n"
" float local sum_mnn[SOFTMAX_LOCAL_SIZE];\n"
" float local max_mnn[SOFTMAX_LOCAL_SIZE];\n"
" /*Compute Max */\n"
" float maxValue=(float)(-FLT_MAX);\n"
" // clip to seq_len\n"
" for (int i=lid; i<inside_len; i+=SOFTMAX_LOCAL_SIZE) {\n"
" maxValue=fmax(maxValue,(float)input[offset+ i]);\n"
" }\n"
" max_mnn[lid]=maxValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" #pragma unroll\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i >>= 1){\n"
" if (lid<i)\n"
" max_mnn[lid]=fmax(max_mnn[lid],max_mnn[lid+i]);\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" maxValue=max_mnn[0];\n"
" /*Compute Exp Sum*/\n"
" float sumValue=0;\n"
" for (int i=lid; i<inside_len; i+=SOFTMAX_LOCAL_SIZE) {\n"
" sumValue += exp((float)input[offset+ i]-maxValue);\n"
" }\n"
" sum_mnn[lid]=sumValue;\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" #pragma unroll\n"
" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i >>= 1){\n"
" if (lid<i)\n"
" sum_mnn[lid]=sum_mnn[lid]+sum_mnn[lid+i];\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" sumValue=sum_mnn[0];\n"
" #ifdef OUTPUT_TRANSPOSE\n"
" const int out_offset=(outside*shape.z+0)*shape.y+axis;\n"
" #endif\n"
" /*Compute Result */\n"
" for (int i=lid; i<inside_len; i+=SOFTMAX_LOCAL_SIZE) {\n"
" float value=exp((float)input[offset+ i]-maxValue)/sumValue;\n"
" #ifdef OUTPUT_TRANSPOSE\n"
" output[out_offset+ i*shape.y]=value;\n"
" #else\n"
" output[offset+ i]=value;\n"
" #endif\n"
" }\n"
" if(shape.z>inside_len){\n"
" for(int i=lid+inside_len; i<shape.z; i+=SOFTMAX_LOCAL_SIZE){\n"
" #ifdef OUTPUT_TRANSPOSE\n"
" output[out_offset+ i*shape.y]=(FLOAT)0;\n"
" #else\n"
" output[offset+ i]=(FLOAT)0;\n"
" #endif\n"
" }\n"
" }\n"
// Serial variant: the same three passes without local-memory cooperation.
"#else\n"
" /*Compute Max */\n"
" float maxValue=(float)(-FLT_MAX);\n"
" // clip to seq_len\n"
" for (int i=0; i<inside_len; i++) {\n"
" maxValue=fmax(maxValue,(float)input[offset+ i]);\n"
" }\n"
" /*Compute Exp Sum*/\n"
" float sumValue=0;\n"
" for (int i=0; i<inside_len; i++) {\n"
" sumValue += exp((float)input[offset+ i]-maxValue);\n"
" }\n"
" #ifdef OUTPUT_TRANSPOSE\n"
" const int out_offset=(outside*shape.z+0)*shape.y+axis;\n"
" #endif\n"
" /*Compute Result */\n"
" for (int i=0; i<inside_len; i++) {\n"
" float value=exp((float)input[offset+ i]-maxValue)/sumValue;\n"
" #ifdef OUTPUT_TRANSPOSE\n"
" output[out_offset+ i*shape.y]=value;\n"
" #else\n"
" output[offset+ i]=value;\n"
" #endif\n"
" }\n"
" if(shape.z>inside_len){\n"
" for(int i=inside_len; i<shape.z; i++){\n"
" #ifdef OUTPUT_TRANSPOSE\n"
" output[out_offset+ i*shape.y]=(FLOAT)0;\n"
" #else\n"
" output[offset+ i]=(FLOAT)0;\n"
" #endif\n"
" }\n"
" }\n"
"#endif\n"
"}\n"
// ---------------------------------------------------------------------------
// Kernel: trans_3d_buf
// Each work item loads an 8x8 tile (8 rows of `height` stride) and stores its
// transpose (8 rows of `width` stride). No edge handling is done inside the
// tile, so width and height are presumably multiples of 8 -- confirm at caller.
// ---------------------------------------------------------------------------
"// [N X Y4 4] -> [N Y X]\n"
"__kernel void trans_3d_buf(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT* input,\n"
" __global FLOAT* output,\n"
" __private const int batch,\n"
" __private const int width,\n"
" __private const int height\n"
") {\n"
" int b=get_global_id(2);\n"
" int w=get_global_id(0);\n"
" int h=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM3(w,h,b);\n"
" w=w << 3;\n"
" h=h << 3;\n"
" \n"
" const int inp_offset=(b*width+w)*height+h;\n"
" const int out_offset=(b*height+h)*width+w;\n"
" FLOAT8 value_0=vload8(0,input+inp_offset);\n"
" FLOAT8 value_1=vload8(0,input+inp_offset+height);\n"
" FLOAT8 value_2=vload8(0,input+inp_offset+height+height);\n"
" FLOAT8 value_3=vload8(0,input+inp_offset+height+height+height);\n"
" FLOAT8 value_4=vload8(0,input+inp_offset+(height << 2));\n"
" FLOAT8 value_5=vload8(0,input+inp_offset+height*5);\n"
" FLOAT8 value_6=vload8(0,input+inp_offset+height*6);\n"
" FLOAT8 value_7=vload8(0,input+inp_offset+height*7);\n"
" \n"
" vstore8((FLOAT8){value_0.s0,value_1.s0,value_2.s0,value_3.s0,value_4.s0,value_5.s0,value_6.s0,value_7.s0},0,output+out_offset);\n"
" vstore8((FLOAT8){value_0.s1,value_1.s1,value_2.s1,value_3.s1,value_4.s1,value_5.s1,value_6.s1,value_7.s1},0,output+out_offset+width);\n"
" vstore8((FLOAT8){value_0.s2,value_1.s2,value_2.s2,value_3.s2,value_4.s2,value_5.s2,value_6.s2,value_7.s2},0,output+out_offset+width+width);\n"
" vstore8((FLOAT8){value_0.s3,value_1.s3,value_2.s3,value_3.s3,value_4.s3,value_5.s3,value_6.s3,value_7.s3},0,output+out_offset+width+width+width);\n"
" vstore8((FLOAT8){value_0.s4,value_1.s4,value_2.s4,value_3.s4,value_4.s4,value_5.s4,value_6.s4,value_7.s4},0,output+out_offset+(width << 2));\n"
" vstore8((FLOAT8){value_0.s5,value_1.s5,value_2.s5,value_3.s5,value_4.s5,value_5.s5,value_6.s5,value_7.s5},0,output+out_offset+width*5);\n"
" vstore8((FLOAT8){value_0.s6,value_1.s6,value_2.s6,value_3.s6,value_4.s6,value_5.s6,value_6.s6,value_7.s6},0,output+out_offset+width*6);\n"
" vstore8((FLOAT8){value_0.s7,value_1.s7,value_2.s7,value_3.s7,value_4.s7,value_5.s7,value_6.s7,value_7.s7},0,output+out_offset+width*7);\n"
"}\n"
// ---------------------------------------------------------------------------
// Kernel: clip_transpose_qkv
// Copies a 4 (head_dim) x 4 (sequence) tile from the padded input to the
// packed output. The output has exactly head_dim entries per head, so rows
// 4*hd+k with 4*hd+k >= head_dim are padding and must not be stored: writing
// them would land in the next head's data (or past the end of the buffer).
// ---------------------------------------------------------------------------
"__kernel void clip_transpose_qkv(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,// [Batch*mNumHead,ROUND_UP(mHeadDim,tile),ROUND_UP(seqLen,tile)]\n"
" __global FLOAT *output,// [Batch,seqLen/4,mNumHead*mHeadDim,4]\n"
" __private const int tile,\n"
" __private const int seq_len,\n"
" __private const int seq_len_piece,\n"
" __private const int head_num,\n"
" __private const int head_dim,\n"
" __private const int batch,\n"
" __private const int seq_index\n"
") {\n"
" \n"
" const int sl=get_global_id(0); // seqLen_Piece_4\n"
" const int hd=get_global_id(1); // mHeadDim_4\n"
" const int z=get_global_id(2); // Batch*mNumHead\n"
" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n"
" \n"
" const int b=z/head_num;\n"
" const int hn=z % head_num;\n"
" \n"
" const int seq_len_4=(seq_len+3)/4;\n"
" \n"
" if(seq_index*seq_len_piece/4+sl >= seq_len_4) {\n"
" return;\n"
" }\n"
" const int seq_len_pack=seq_len_piece;//((seq_len+tile-1)/tile)*tile;\n"
" const int head_dim_pack=((head_dim+tile-1)/tile)*tile;\n"
" \n"
" const int offset_inp=((b*head_num+hn)*head_dim_pack+4*hd)*seq_len_pack+4*sl;\n"
" \n"
" const int offset_out=((((seq_index*seq_len_piece/4+sl)*batch+b)*head_num+hn)*head_dim+4*hd)*4;\n"
" // Q\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+seq_len_pack);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+2*seq_len_pack);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+3*seq_len_pack);\n"
" \n"
// Guards use >= so that row index == head_dim (first padded row) is skipped.
" vstore4(temp_0,0,output+offset_out);\n"
" if(4*hd+1 >= head_dim) return;\n"
" vstore4(temp_1,0,output+offset_out+4);\n"
" if(4*hd+2 >= head_dim) return;\n"
" vstore4(temp_2,0,output+offset_out+8);\n"
" if(4*hd+3 >= head_dim) return;\n"
" vstore4(temp_3,0,output+offset_out+12);\n"
"}\n"
;
#endif
}
